
Draft: Initial draft implementation of CFG for LLM #996

Open

wants to merge 3 commits into master from feature/llm-classifier-free-guidance

Conversation

@ottonemo (Member) commented Jul 19, 2023

Based on the paper

    Sanchez, Guillaume, et al.
    "Stay on topic with Classifier-Free Guidance."
    arXiv preprint arXiv:2306.17806 (2023).

this PR is a draft implementation of classifier-free guidance (CFG).

This is simply for sharing internally and might very well be completely wrong. It is debatable whether we should expose such a feature as a flag to the network or make it a separate classifier instance (or a mixin). In the past we were very much against special (potentially short-lived) feature flags, and it was much nicer to have such functionality implemented as an addon/callback. We might need to do something similar here as well.

Open tasks:

  • evaluate existing examples
  • write explicit test cases
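
For context, a minimal sketch of the logit reweighting the paper describes, assuming conditional logits from the prompt-plus-input pass and unconditional logits from a second pass; the function name and the 1.5 default are illustrative, not the actual code in this PR:

```python
import torch

def cfg_logits(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_gamma=1.5):
    """Classifier-free guidance on next-token logits.

    Pushes the conditional distribution away from the unconditional one;
    cfg_gamma=1 (or None) recovers the plain conditional logits, values
    above 1 strengthen the effect of the conditioning.
    """
    if cfg_gamma is None or cfg_gamma == 1.0:
        return cond_logits  # no CFG, skip the reweighting entirely
    return uncond_logits + cfg_gamma * (cond_logits - uncond_logits)
```

Since cfg_gamma = 1 collapses the formula back to the plain conditional logits, a single gamma parameter could also serve as the on/off switch.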

@BenjaminBossan (Collaborator) commented:

The paper in question is this one:

https://arxiv.org/abs/2306.17806

Note that this method should have a greater effect the longer the labels are.

Some random comments:

  • At the moment, two forward passes are needed. Shouldn't we be able to pre-compute (or cache) P_wi_wji, since the labels are always the same and known from the start? (See the sketch after this list.)
  • How about exposing cfg_gamma instead of use_cfg? If it is 1 (or None), don't use CFG; otherwise apply that gamma instead of basically hard-coding it to 1.5.
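
A minimal sketch of the caching idea from the first bullet, assuming a hypothetical `generate_logits` callable that scores a fixed sequence of label token ids; the class and names are illustrative, not skorch's actual code:

```python
class UncondLogitsCache:
    """Memoize the unconditional pass per label.

    The candidate labels are fixed and known up front, and the
    unconditional score depends only on the label tokens, not on the
    sample. The logits for each label therefore need to be computed only
    once and can be reused for every prediction, saving one of the two
    forward passes per sample.
    """

    def __init__(self, generate_logits):
        self.generate_logits = generate_logits
        self._cache = {}

    def __call__(self, token_ids):
        key = tuple(token_ids)
        if key not in self._cache:
            self._cache[key] = self.generate_logits(token_ids)
        return self._cache[key]
```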

It is debatable if we should expose such a feature as a flag to the network or make it a separate classifier instance (or a mixin). In the past we were very much against special (potentially short-lived) feature flags and it was much nicer to have this implemented as an addon/callback.

If this method works really well, I can see it being added explicitly. Alternatively, we could have a callbacks equivalent for logits processors, with _LogitsRecorder being the default.

- Makes it possible to set the gamma parameter
- Setting it to `None` disables the functionality completely
@ottonemo force-pushed the feature/llm-classifier-free-guidance branch from 96de091 to 1c34aca on July 19, 2023 at 18:45
- `label_id` was misleading since it is actually a list of token ids
  related to a label and not a scalar value. Also, the general process
  of generating logits is not related to labels at all but rather just
  to tokens.

- `kwargs` was named to match the transformers `generate` convention,
  but it is meant to be passed to `generate` and is therefore,
  in the context of `generate_logits`, a model input. This should help
  the reader distinguish between the expected input (`token_ids`) and
  the model input (`model_input`).
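
For illustration only, a self-contained sketch of the renamed idea: record the model's logits for a fixed sequence of token ids while forwarding everything else to `generate`. The recorder below is a stand-in, not skorch's actual `_LogitsRecorder`, and the signatures are assumptions:

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList)

class RecordAndForce(LogitsProcessor):
    """Stand-in recorder: store the raw logits of each generation step,
    then force generation to follow ``token_ids`` so the recorded scores
    correspond exactly to those tokens."""

    def __init__(self, token_ids):
        self.token_ids = token_ids
        self.recorded = []

    def __call__(self, input_ids, scores):
        step = len(self.recorded)
        self.recorded.append(scores[0].detach().clone())
        forced = torch.full_like(scores, float("-inf"))
        forced[0, self.token_ids[step]] = 0.0  # only the expected token survives
        return forced

def generate_logits(model, *, token_ids, **model_input):
    """Return one logits vector per id in ``token_ids``.

    ``token_ids`` is the expected input (e.g. a tokenized label), while
    ``model_input`` is forwarded to ``generate``, mirroring the
    transformers convention mentioned in the commit message."""
    recorder = RecordAndForce(token_ids)
    model.generate(
        logits_processor=LogitsProcessorList([recorder]),
        max_new_tokens=len(token_ids),
        **model_input,
    )
    return torch.stack(recorder.recorded)

# usage example (downloads a small model on first run)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
prompt = tokenizer("The movie was", return_tensors="pt")
label_ids = tokenizer(" positive", add_special_tokens=False)["input_ids"]
print(generate_logits(model, token_ids=label_ids, **prompt).shape)
```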